 # -*- coding: utf-8 -*-
"""
Created on Fri Sep 17 10:29:29 2021
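计算出制备到 Bell 态的测试集保真度分布
Compute the fidelity distribution over the testing set for a trained DQN agent that prepares the two-qubit Bell state.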
@author: Waikikilick
"""

import  os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' #设置 tf 的报警等级
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses,initializers,Sequential,metrics,models
import copy 
from collections import deque
import random
from scipy.linalg import expm
from time import *

# Make runs reproducible. Seed, in this order, TensorFlow, NumPy, and Python's
# `random` module; the last one drives random.shuffle in env.psi_set, so it
# fixes the training/validation/testing split.
for _seed_fn in (tf.random.set_seed, np.random.seed, random.seed):
    _seed_fn(1)
del _seed_fn

class env(object):
    """Two-qubit state-preparation environment whose target is a Bell state.

    The target is (|00> + |11>) / sqrt(2). The agent sees a state as a real
    (8, 1) column vector: the real parts of the four complex amplitudes,
    followed by their imaginary parts (see ``reset`` / ``step``).

    Each action selects a pair of control strengths ``(J_1, J_2)`` from
    ``action_space``. One ``step`` evolves the state under the resulting
    Hamiltonian for time ``dt``.
    """

    def __init__(self, dt=np.pi / 2, noise_a=0):
        """Set up operators, the action grid and the shuffled state sets.

        Args:
            dt: evolution time of a single control step.
            noise_a: accepted for interface compatibility; it is not used
                anywhere in this class.
        """
        # All 25 (J_1, J_2) pairs with J_1, J_2 in 1..5, J_2 varying fastest.
        self.action_space = np.array(
            [[j_1, j_2] for j_1 in range(1, 6) for j_2 in range(1, 6)]
        )
        self.n_actions = len(self.action_space)
        self.n_features = 8  # length of the real state vector (4 real + 4 imaginary parts)
        # Target Bell state (|00> + |11>)/sqrt(2), as a (4, 1) complex column.
        self.target_psi = np.asmatrix([[1], [0], [0], [1]], dtype=complex) / np.sqrt(2)
        self.h_1 = 1
        self.h_2 = 1
        self.I = np.matrix(np.identity(2, dtype=complex))
        self.s_x = np.asmatrix([[0, 1], [1, 0]], dtype=complex)
        self.s_z = np.asmatrix([[1, 0], [0, -1]], dtype=complex)
        self.dt = dt
        self.training_set, self.validation_set, self.testing_set = self.psi_set()

    def psi_set(self):
        """Build the shuffled training / validation / testing sets of initial states.

        Each state is a normalized 4-component complex column in
        hyperspherical form:

        * The moduli come from three polar angles, each in
          {pi/8, pi/4, 3pi/8}.
        * Each amplitude has its own phase, each in {0, pi/2, pi, 3pi/2}.

        This gives 3**3 * 4**4 = 6912 states. The list is shuffled with
        Python's global ``random`` generator, so the split depends on its
        seed.

        Returns:
            (training_set, validation_set, testing_set): the first 256 states,
            the next 256, and the remaining ones. Each is a list of (4, 1)
            ``np.matrix`` objects.
        """
        import itertools

        theta = [np.pi / 8, np.pi / 4, 3 * np.pi / 8]
        alpha = np.linspace(0, np.pi * 2, 4, endpoint=False)

        psi_set = []
        # product() yields combinations in the same order as seven nested loops
        # (last phase varying fastest). The order matters: the seeded shuffle
        # below, and therefore the split, depends on it.
        for t_1, t_2, t_3, *phases in itertools.product(
            theta, theta, theta, alpha, alpha, alpha, alpha
        ):
            moduli = (
                np.cos(t_1),
                np.sin(t_1) * np.cos(t_2),
                np.sin(t_1) * np.sin(t_2) * np.cos(t_3),
                np.sin(t_1) * np.sin(t_2) * np.sin(t_3),
            )
            amplitudes = [
                m * np.cos(p) + (m * np.sin(p)) * 1j for m, p in zip(moduli, phases)
            ]
            psi_set.append(np.asmatrix([[a] for a in amplitudes]))

        random.shuffle(psi_set)  # shuffle before splitting

        training_set = psi_set[0:256]
        validation_set = psi_set[256:512]
        testing_set = psi_set[512:]
        return training_set, validation_set, testing_set

    def reset(self, init_psi):
        """Start a new episode from ``init_psi``; return its real-vector form.

        ``init_psi`` is a (4, 1) complex column. For that shape, ``tolist()``
        gives ``[[a], [b], [c], [d]]``. Concatenating the real and imaginary
        lists therefore yields an (8, 1) array: real parts first, then
        imaginary parts.
        """
        init_state = np.array(init_psi.real.tolist() + init_psi.imag.tolist())
        return init_state

    def step(self, state, action, nstep):
        """Apply one control pulse and report the new state and fidelity.

        Args:
            state: real vector as returned by ``reset`` / ``step`` (real parts
                then imaginary parts); a 1-D vector of the same length is also
                accepted.
            action: index into ``action_space``.
            nstep: number of steps already taken in this episode.

        Returns:
            (state, done, fid):

            * state: the next real (8, 1) vector.
            * fid: |<psi|target>|^2 for the evolved state.
            * done: True when ``1 - fid < 1e-3`` or
              ``nstep >= 20 * pi / dt``.
        """
        half = len(state) // 2
        # Rebuild the complex (4, 1) column from the real and imaginary halves.
        psi = np.asmatrix(np.reshape(state[:half] + state[half:] * 1j, (-1, 1)))

        J_1, J_2 = self.action_space[action, 0], self.action_space[action, 1]  # control field strengths
        J_12 = J_1 * J_2 / 2

        H = (J_1 * np.kron(self.s_z, self.I) + J_2 * np.kron(self.I, self.s_z)
             + J_12 / 2 * np.kron((self.s_z - self.I), (self.s_z - self.I))
             + self.h_1 * np.kron(self.s_x, self.I) + self.h_2 * np.kron(self.I, self.s_x)) / 2
        # asmatrix makes '*' an explicit matrix product, whether expm returns an
        # ndarray or a matrix.
        U = np.asmatrix(expm(-1j * H * self.dt))
        psi = U * psi  # next state

        err = 1e-3  # fidelity tolerance (same value as the former literal 10e-4)
        fid = (np.abs(psi.H * self.target_psi) ** 2).item(0).real

        done = ((1 - fid) < err) or nstep >= 20 * np.pi / self.dt

        state = np.array(psi.real.tolist() + psi.imag.tolist())  # back to real-vector form
        return state, done, fid

def relu(vector):
    """Clamp ``vector`` at zero in place and return the very same object.

    The clamp is applied to each item along the first axis. For a (1, n)
    ``np.matrix`` row this means one ``np.maximum`` call on the whole row.
    """
    for position, entry in enumerate(vector):
        vector[position] = np.maximum(0, entry)
    return vector

def agent_matrix(x, weights=None):
    """Return the greedy action index of the Q-network for observation ``x``.

    Re-implements the network's forward pass in NumPy. Consecutive
    (kernel, bias) pairs are applied as dense layers,
    ``out = out . kernel + bias``, each followed by ReLU. The index of the
    largest output is returned.

    NOTE(review): the trailing ReLU on the last layer assumes the saved Keras
    model's output layer also uses ReLU. If it is linear (typical for DQN),
    clamping negative Q-values to 0 can change the argmax. Confirm against the
    saved model.

    Args:
        x: observation, e.g. the (8, 1) real vector from ``env.reset`` /
            ``env.step``. A 1-D vector of the same length is also accepted.
        weights: flat sequence ``[kernel_0, bias_0, kernel_1, bias_1, ...]``.
            Items may be TF variables (anything with ``.numpy()``) or
            array-likes. Defaults to the module-level ``net``, which the
            ``__main__`` block sets to the loaded model's
            ``trainable_variables``.

    Returns:
        The index (NumPy integer) of the largest network output.

    Raises:
        ValueError: if ``weights`` is empty or has an odd length.
    """
    if weights is None:
        weights = net
    if len(weights) == 0 or len(weights) % 2:
        raise ValueError("weights must be a non-empty list of (kernel, bias) pairs")

    def _as_array(value):
        # TF variables expose .numpy(); plain arrays/lists go through np.asarray.
        return value.numpy() if hasattr(value, "numpy") else np.asarray(value)

    out = np.reshape(np.asarray(x), (1, -1))  # row vector (1, n_inputs)
    for kernel, bias in zip(weights[0::2], weights[1::2]):
        out = np.dot(out, _as_array(kernel)) + _as_array(bias)
        out = np.maximum(0, out)  # ReLU, same values as relu()
    action = np.argmax(out)
    return action

def testing_fid_distribution(): # 测试 测试集中的点 得到 保真度 分布
    print('\n保真度分布计算中, 请稍等...')
    
    fid_dis_list = []
    
    testing_set = env.testing_set
    fid_list = []
    
    for test_init_psi in testing_set:
        
        fid_max = 0
        observation = env.reset(test_init_psi)
        nstep = 0 
        while True:
            action = agent_matrix(observation) 
            observation, done, fid = env.step(observation, action, nstep)  
            nstep += 1
            fid_max = max(fid_max, fid)
            if done:
                break
        fid_dis_list.append(fid_max)
        
    return fid_dis_list
       
def fid_sort(fid_dis_list):
    """Print a histogram of fidelities in 0.01-wide bins and return the counts.

    Each value goes into the bin of its fidelity rounded to two decimals:

    * Bin ``k`` (50..99) holds values in [(k - 0.5)/100, (k + 0.5)/100).
    * Bin 100 holds every value >= 0.995.
    * Bin 0 holds values in [0, 0.495).

    Values below 0 (and NaN) are not counted.

    Printed layout: one line for bin 0, then one line of five counts per
    0.05 range from 0.50 to 0.94, then a final line with bins 95..100.

    Args:
        fid_dis_list: iterable of fidelity values.

    Returns:
        Dict mapping bin key (0, 50, 51, ..., 100) to its count.
    """
    # Lower edges for bins 50..100: 0.495, 0.505, ..., 0.995.
    # (495 + 10k) / 1000 is exactly the float the decimal literal would give.
    edges = [(495 + 10 * k) / 1000 for k in range(51)]
    counts = dict.fromkeys([0, *range(50, 101)], 0)

    for item in fid_dis_list:
        # Check the highest bin first, so each value lands in the top bin whose
        # edge it reaches.
        for k in range(50, -1, -1):
            if item >= edges[k]:
                counts[50 + k] += 1
                break
        else:
            # Written as `item >= 0` (not `not item < 0`) so NaN is skipped.
            if item >= 0:
                counts[0] += 1

    print('0.00-0.49: ', counts[0])
    for start in range(50, 95, 5):
        print(f'0.{start}-0.{start + 4}: ', *(counts[b] for b in range(start, start + 5)))
    print('0.95-1.00: ', *(counts[b] for b in range(95, 101)))
    return counts
  
if __name__ == "__main__":
    
    # Evolution time of one control step, passed to the environment.
    dt = np.pi/2
    
    network = keras.models.load_model('dqn_tf2_two_qubit_to_Bell_saved_model') # load the trained network (relative path, resolved against the current working directory)
    # agent_matrix() reads the network weights through this module-level name.
    net = network.trainable_variables
    
    # NOTE(review): this rebinds the name `env` from the class to an instance,
    # so the class cannot be instantiated again after this line. It must stay a
    # global named `env`, because testing_fid_distribution() reads
    # env.testing_set / env.reset / env.step through it.
    env = env(dt = dt)
    
    # NOTE(review): this global appears unused elsewhere in this file.
    testing_set = env.testing_set          
    
    # Testing: roll out every testing-set state, then print the histogram.
    fid_dis_list = testing_fid_distribution()
    
    fid_sort(fid_dis_list)

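# Output recorded from an earlier run (the counts sum to 6400, the testing-set size):
# 保真度分布计算中, 请稍等...   ("Computing the fidelity distribution, please wait...")
# 0.00-0.49:  0
# 0.50-0.54:  0 0 0 0 0
# 0.55-0.59:  0 0 0 0 0
# 0.60-0.64:  0 0 0 0 0
# 0.65-0.69:  0 0 0 0 0
# 0.70-0.74:  0 0 0 0 0
# 0.75-0.79:  0 0 0 0 0
# 0.80-0.84:  0 0 0 0 0
# 0.85-0.89:  0 0 0 0 0
# 0.90-0.94:  0 7 15 106 302
# 0.95-1.00:  679 1029 1588 1715 920 39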